{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Training with assisted logic\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/sessionrunhook_callback\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/sessionrunhook_callback.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/sessionrunhook_callback.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/sessionrunhook_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzPg7zOP0Dzy"
      },
      "source": [
        "## Migrating a TF1 custom SessionRunHook to TF2 Keras Callback"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KZHPY55aFyXT"
      },
      "source": [
        "In TF1, a `tf.estimator.SessionRunHook` runs custom logic while training your model with `tf.estimator.Estimator`, while in TF2, a `tf.keras.callbacks.Callback` runs custom logic within Keras `Model.fit` for training. Examples of common tasks performed by such custom logic include checkpoint saving and TensorBoard summary writing. This example demonstrates equivalency between the two in TF1 and TF2 while measuring examples per second for each step, with a custom `SessionRunHook` and a custom `Callback`, respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "296d8b0DoKpV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xVGYtUXyXNuE"
      },
      "outputs": [],
      "source": [
        "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
        "labels = [[0.3], [0.5], [0.7]]\n",
        "eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]\n",
        "eval_labels = [[0.8], [0.9], [1.]]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ON4zQifT0Vec"
      },
      "source": [
        "### TF1: Estimator with a SessionRunHook\n",
        "The following code utilizes TF1 to demonstrate an estimator's custom SessionRunHook measuring examples per second during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S-myEclbXUL7"
      },
      "outputs": [],
      "source": [
        "def _input_fn():\n",
        "  return tf1.data.Dataset.from_tensor_slices(\n",
        "      (features, labels)).batch(1).repeat(100)\n",
        "\n",
        "def _model_fn(features, labels, mode):\n",
        "  logits = tf1.layers.Dense(1)(features)\n",
        "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n",
        "  optimizer = tf1.train.AdagradOptimizer(0.05)\n",
        "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
        "  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xd9sPTkO0ZTD"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "from datetime import datetime\n",
        "from absl import flags\n",
        "\n",
        "class LoggerHook(tf1.train.SessionRunHook):\n",
        "  \"\"\"Logs loss and runtime.\"\"\"\n",
        "\n",
        "  def begin(self):\n",
        "    self._step = -1\n",
        "    self._start_time = time.time()\n",
        "    self.log_frequency = 10\n",
        "\n",
        "  def before_run(self, run_context):\n",
        "    self._step += 1\n",
        "\n",
        "  def after_run(self, run_context, run_values):\n",
        "    if self._step % self.log_frequency == 0:\n",
        "      current_time = time.time()\n",
        "      duration = current_time - self._start_time\n",
        "      self._start_time = current_time\n",
        "      examples_per_sec = self.log_frequency / duration\n",
        "      print('Time:', datetime.now(), ', Step #:', self._step,\n",
        "            ', Examples per second:', examples_per_sec)\n",
        "\n",
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "estimator.train(_input_fn, hooks = [LoggerHook()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3uZCDMrM2CEg"
      },
      "source": [
        "### TF2: Keras model with callback\n",
        "You can use a custom `tf.keras.callbacks.Callback` to add a similar collection of metrics to a Keras model.  In this case, it will also measure examples per second which should be comparable to the metrics in TF1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UbMPoiB92KRG"
      },
      "outputs": [],
      "source": [
        "class CustomCallback(tf.keras.callbacks.Callback):\n",
        "\n",
        "    def on_train_begin(self, logs = None):\n",
        "      self._step = -1\n",
        "      self._start_time = time.time()\n",
        "      self.log_frequency = 10\n",
        "\n",
        "    def on_train_batch_begin(self, batch, logs = None):\n",
        "      self._step += 1\n",
        "\n",
        "    def on_train_batch_end(self, batch, logs = None):\n",
        "      if self._step % self.log_frequency == 0:\n",
        "        current_time = time.time()\n",
        "        duration = current_time - self._start_time\n",
        "        self._start_time = current_time\n",
        "        examples_per_sec = self.log_frequency / duration\n",
        "        print('Time:', datetime.now(), ', Step #:', self._step,\n",
        "              ', Examples per second:', examples_per_sec)\n",
        "\n",
        "callback = CustomCallback()\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    (features, labels)).batch(1).repeat(100)\n",
        "\n",
        "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n",
        "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
        "\n",
        "model.compile(optimizer, \"mse\")\n",
        "result = model.fit(dataset, callbacks=[callback], verbose = 0)\n",
        "result.history # `result.history` gives the results of training metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EFqFi21Ftskq"
      },
      "source": [
        "#Next Steps\n",
        "\n",
        "*   For more information on how to write your own `tf.keras.callbacks.Callback` check out https://www.tensorflow.org/guide/keras/custom_callback\n",
        "*   For more information on how to write your own `tf.estimator.SessionRunHook` check out https://www.tensorflow.org/api_docs/python/tf/estimator/SessionRunHook"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "sessionrunhook_callback.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
